import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns

# Test for Animation of EnKF States
# Imports
sys.path.append('../dust/Projects/ABM_DA/stationsim/')
# from stationsim_gcs_model import Model
# from ensemble_kalman_filter import EnsembleKalmanFilter, EnsembleKalmanFilterType, ActiveAgentNormaliser
from stationsim_gcs_model import Model
from ensemble_kalman_filter import EnsembleKalmanFilter, EnsembleKalmanFilterType

%matplotlib inline
# Functions
def __make_exit_observation_operator(population_size):
    """Build the observation operator used by the DUAL_EXIT filter mode.

    The result has shape ``(2 * population_size, 3 * population_size)``:
    an identity block that selects the first ``2 * population_size`` entries
    of the state vector, followed by ``population_size`` all-zero columns so
    the remaining state entries do not contribute to the observation.
    """
    n_observed = 2 * population_size
    selector = np.eye(n_observed)
    unobserved = np.zeros((n_observed, population_size))
    return np.concatenate([selector, unobserved], axis=1)
def __make_observation_operator(population_size, mode):
    """Return the observation operator ``H`` for the given filter mode.

    Parameters
    ----------
    population_size : int
        Number of agents in the model.
    mode : EnsembleKalmanFilterType
        ``STATE`` yields a square identity of size ``2 * population_size``;
        ``DUAL_EXIT`` yields that identity padded with ``population_size``
        zero columns (built by ``__make_exit_observation_operator``).

    Raises
    ------
    ValueError
        If ``mode`` is neither ``STATE`` nor ``DUAL_EXIT``.
    """
    if mode == EnsembleKalmanFilterType.STATE:
        n_observed = 2 * population_size
        return np.identity(n_observed)
    if mode == EnsembleKalmanFilterType.DUAL_EXIT:
        return __make_exit_observation_operator(population_size)
    raise ValueError(f'Unexpected filter mode: {mode}')
def __make_state_vector_length(population_size, mode):
    """Return the length of the filter state vector for ``mode``.

    ``STATE`` uses 2 entries per agent and ``DUAL_EXIT`` uses 3 entries per
    agent, so the result is that count multiplied by ``population_size``.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``STATE`` nor ``DUAL_EXIT``.
    """
    if mode == EnsembleKalmanFilterType.STATE:
        entries_per_agent = 2
    elif mode == EnsembleKalmanFilterType.DUAL_EXIT:
        entries_per_agent = 3
    else:
        raise ValueError(f'Unexpected filter mode: {mode}')
    return entries_per_agent * population_size
def run_enkf(filter_params, model_params, normaliser, station, pickle_path, filter_id):
    """Run an EnsembleKalmanFilter to completion and pickle the finished filter.

    Parameters
    ----------
    filter_params : dict
        Base filter parameters. A shallow copy is taken and its
        ``'error_normalisation'`` entry is set to ``normaliser``; the caller's
        dict is left untouched.
    model_params : dict
        Base model parameters. A shallow copy is taken and its ``'station'``
        entry is set to ``station``; the caller's dict is left untouched.
    normaliser : object or None
        Error normaliser; its ``.name`` is used in the output filename, or
        ``'default'`` when None (presumably an ActiveAgentNormaliser member).
    station : str or None
        Station passed to the model; used in the filename, ``'toy'`` when None.
    pickle_path : str
        Prefix prepended verbatim to the filename. It is not path-joined, so a
        directory prefix must end with a path separator.
    filter_id
        Identifier appended to the output filename.
    """
    # Copy so that per-run settings do not leak back into the caller's dicts.
    filter_params = dict(filter_params)
    model_params = dict(model_params)
    filter_params['error_normalisation'] = normaliser
    model_params['station'] = station

    enkf = EnsembleKalmanFilter(Model, filter_params, model_params,
                                filtering=True, benchmarking=True)
    while enkf.active:
        enkf.step()

    norm = normaliser.name if normaliser is not None else 'default'
    mt = station if station is not None else 'toy'
    s = pickle_path + f'filter_{mt}_{norm}_{filter_id}.pkl'
    with open(s, 'wb') as f:
        pickle.dump(enkf, f)


def get_rows(xs, ys, model_type, id_prefix, output):
    """Build one row dict per agent for a single model state at one time step.

    Each row is a shallow copy of ``output`` extended with ``model_type``,
    ``agent_id`` (``f'{id_prefix}_agent_{i}'``), ``x`` and ``y``. ``ys`` is
    indexed in step with ``xs``, so it must be at least as long as ``xs``.
    ``output`` itself is never modified.
    """
    rows = []
    for i, x in enumerate(xs):
        row = output.copy()
        row['model_type'] = model_type
        row['agent_id'] = f'{id_prefix}_agent_{i}'
        row['x'] = x
        row['y'] = ys[i]
        rows.append(row)
    return rows

# Constants
# Model arena extent; used below for the scatter plot's axis ranges and figure size.
model_width, model_height = 740, 700

# Run EnKF
# Experiment configuration.
its = 200
pop_size = 20
ensemble_size = 20
assimilation_period = 20
obs_noise_std = 1.0
mode = EnsembleKalmanFilterType.STATE

# Model parameters, including the station passed through to the model.
model_params = dict(pop_total=pop_size, do_print=False, station='Grand_Central')

# Observation operator and vector lengths. The data vector is always sized as
# in STATE mode (2 * pop_size entries), matching the rows of the operator.
observation_operator = __make_observation_operator(pop_size, mode)
state_vec_length = __make_state_vector_length(pop_size, mode)
data_mode = EnsembleKalmanFilterType.STATE
data_vec_length = __make_state_vector_length(pop_size, data_mode)

# Set up filter parameters.
# NOTE(review): the recorded run below shows EnsembleKalmanFilter warning that
# 'vanilla_ensemble_size' is an unexpected attribute; confirm whether the key
# is still needed.
filter_params = dict(
    max_iterations=its,
    assimilation_period=assimilation_period,
    ensemble_size=ensemble_size,
    population_size=pop_size,
    vanilla_ensemble_size=ensemble_size,
    state_vector_length=state_vec_length,
    data_vector_length=data_vec_length,
    mode=mode,
    H=observation_operator,
    R_vector=obs_noise_std * np.ones(data_vec_length),
    keep_results=True,
    run_vanilla=True,
    vis=False,
)
# filter_params['error_normalisation'] = ActiveAgentNormaliser.BASE

enkf = EnsembleKalmanFilter(Model, filter_params, model_params,
                            filtering=True, benchmarking=True)
# Advance the filter until it reports it is no longer active.
while enkf.active:
    enkf.step()

# Recorded output from a previous run of this cell:
# ../dust/Projects/ABM_DA/stationsim/ensemble_kalman_filter.py:215: RuntimeWarning: EnKF received unexpected attribute (vanilla_ensemble_size).
#   warns.warn(w, RuntimeWarning)
# Running Ensemble Kalman Filter...
# max_iterations: 200
# ensemble_size: 20
# assimilation_period: 20
# pop_size: 20
# filter_type: EnsembleKalmanFilterType.STATE
# inclusion_type: None
# ensemble_errors: False
Process results
# Per-step states whose result key, model_type label and agent_id prefix are
# all the same string, in the order their rows are emitted.
_SUMMARY_KEYS = ('observation', 'ground_truth', 'baseline', 'prior', 'posterior')
_ENDPOINT_KEYS = ('destination', 'origin')


def _state_rows(step_result, key, model_type, id_prefix, base_row):
    """Return get_rows() output for the state vector stored at step_result[key]."""
    xs, ys = enkf.separate_coords(step_result[key])
    return get_rows(xs, ys, model_type, id_prefix, base_row)


# Flatten every recorded step into one long table: one row per
# (time, model_type, agent) holding that agent's x/y coordinates.
results = []
for step_result in enkf.results:
    base_row = {'time': step_result['time']}
    # Observations, ground truth, benchmark model, prior and posterior means.
    for key in _SUMMARY_KEYS:
        results.extend(_state_rows(step_result, key, key, key, base_row))
    # Individual ensemble members: every prior member, then every posterior member.
    for phase in ('prior', 'posterior'):
        for j in range(enkf.ensemble_size):
            results.extend(_state_rows(step_result, f'{phase}_{j}',
                                       f'{phase}_ensemble_member',
                                       f'{phase}_ensemble_member_{j}',
                                       base_row))
    # Agent destinations, then origins.
    for key in _ENDPOINT_KEYS:
        results.extend(_state_rows(step_result, key, key, key, base_row))

results = pd.DataFrame(results)
results.head()
# Output:
# |   | time | model_type  | agent_id            | x          | y          |
# |---|------|-------------|---------------------|------------|------------|
# | 0 | 20   | observation | observation_agent_0 | 179.201482 | 679.897110 |
# | 1 | 20   | observation | observation_agent_1 | 580.208471 | 691.454196 |
# | 2 | 20   | observation | observation_agent_2 | 0.555967   | -2.110513  |
# | 3 | 20   | observation | observation_agent_3 | 0.550657   | -0.114622  |
# | 4 | 20   | observation | observation_agent_4 | -1.027456  | -0.549186  |
def get_agent_number(row):
    """Return the integer after the last underscore in ``row["agent_id"]``.

    ``agent_id`` values have the form ``'<prefix>_agent_<n>'`` (as built by
    ``get_rows``), so this yields ``n``.
    """
    _, _, number = row["agent_id"].rpartition("_")
    return int(number)


# Quick sanity check on a hand-built row.
test_row = {"time": 20, "model_type": "observation",
            "agent_id": "observation_agent_0",
            "x": 15, "y": 25}
print(get_agent_number(test_row))
# Output: 0
# Integer agent index parsed from each row's agent_id (row-wise apply).
results["agent_number"] = results.apply(get_agent_number, axis=1)

# (Re)create a constant 's' column. Any existing 's' is removed first so that
# re-running this cell always leaves 's' as the last column.
# NOTE(review): 's' is presumably intended as a marker size - confirm its use.
if "s" in results.columns:
    del results["s"]
results["s"] = 10
results.head()
# Output:
# |   | time | model_type  | agent_id            | x          | y          | agent_number | s  |
# |---|------|-------------|---------------------|------------|------------|--------------|----|
# | 0 | 20   | observation | observation_agent_0 | 179.201482 | 679.897110 | 0            | 10 |
# | 1 | 20   | observation | observation_agent_1 | 580.208471 | 691.454196 | 1            | 10 |
# | 2 | 20   | observation | observation_agent_2 | 0.555967   | -2.110513  | 2            | 10 |
# | 3 | 20   | observation | observation_agent_3 | 0.550657   | -0.114622  | 3            | 10 |
# | 4 | 20   | observation | observation_agent_4 | -1.027456  | -0.549186  | 4            | 10 |

# Write the processed table to CSV and read it straight back.
animation_csv = "./animation_results.csv"
results.to_csv(animation_csv, index=False)
results = pd.read_csv(animation_csv)

# Create animated scatter
# Bounding box of a circle of diameter clock_size centred at (clock_x, clock_y),
# in the same coordinate system as the plotted agent positions.
clock_x, clock_y = 370, 275
clock_size = 56
half_clock = clock_size / 2
x_l, x_h = clock_x - half_clock, clock_x + half_clock
y_l, y_h = clock_y - half_clock, clock_y + half_clock

# One animation frame per time step; animation_group ties markers with the same
# agent_id together across frames. Axes are fixed to the model extent, and the
# figure is widened by 25% (presumably to leave room for the legend).
f = px.scatter(
    results,
    x='x',
    y='y',
    animation_frame='time',
    animation_group='agent_id',
    color='model_type',
    hover_name='agent_id',
    range_x=[0, model_width],
    range_y=[0, model_height],
    width=1.25 * model_width,
    height=model_height,
)
# Filled black circle marking the clock.
f.add_shape(type="circle", xref="x", yref="y",
            x0=x_l, y0=y_l, x1=x_h, y1=y_h,
            fillcolor="black", line_color="black")
f